###########################################
### Article: Where do parties interact? ###
### Task:    Structural Topic Model     ###
### File:    01_STM.R                   ###
###########################################

#--------------------------------------------------------------------------------------------------
# Description:
#
# This script loads the textual data (party press releases and tweets) and applies a Structural
# Topic Model (STM) to detect topics.
#--------------------------------------------------------------------------------------------------

#---------------------------------------------------------------------------------------------------------------------

# Get start time

# Wall-clock start of the script; subtracted from the end time at the bottom of the file
start_time <- Sys.time()

#---------------------------------------------------------------------------------------------------------------------

# Load packages

library(topicmodels)
library(lda)
library(slam)
library(stm)
library(ggplot2)
library(dplyr)
library(tidytext)
library(xlsx)
library(furrr)
plan(multicore)
library(tm)
library(tidyverse)
library(Rtsne)
library(rsvd)
library(geometry)
library(NLP)
library(ldatuning) 
library(quanteda)
library(reticulate)
library(gridExtra)

#---------------------------------------------------------------------------------------------------------------------

# Load and clean corpora

## Press releases: each cleaning rule is a (pattern, replacement) pair and the
## rules are applied to the text column one after another, in the order listed.
corpus_pressreleases <- readRDS("Corpus_Pressreleases_AT_CH_DE.RDS")
corpus_pressreleases$text <- Reduce(
  function(txt, rule) gsub(rule[["pattern"]], rule[["replacement"]], txt),
  list(
    # stray C1 control code points (U+0082, U+0091) are replaced by a blank
    c(pattern = "\u0082", replacement = " "),
    c(pattern = "\u0091", replacement = " "),
    # URLs are dropped
    c(pattern = "http:\\S+", replacement = ""),
    c(pattern = "https:\\S+", replacement = ""),
    # "Wien (OTS...)" dateline prefixes are dropped
    c(pattern = "Wien \\(OTS[[:punct:]]{0,}S?K?\\)", replacement = ""),
    c(pattern = "Wien \\(OTS[[:punct:]]{0,}ÖVP-PK\\)", replacement = "")
  ),
  corpus_pressreleases$text
)

## Tweets: only URLs are removed.
corpus_tweets <- readRDS("Corpus_Tweets_AT_CH_DE.RDS")
corpus_tweets$text <- Reduce(
  function(txt, pattern) gsub(pattern, "", txt),
  c("http:\\S+", "https:\\S+"),
  corpus_tweets$text
)

## Stack both corpora, keeping only the columns they are combined on
## (press releases first, tweets second).
corp <- do.call(
  rbind,
  lapply(
    list(corpus_pressreleases, corpus_tweets),
    function(corpus) {
      corpus[, c("text_id", "country", "party", "date", "timestamp", "type", "text")]
    }
  )
)
corp$country <- as.factor(corp$country)

#---------------------------------------------------------------------------------------------------------------------

# Preprocessing

## Text processing: German stopword removal and stemming via stm::textProcessor,
## plus a custom stopword list (general terms and party names per country).
processed <- textProcessor(
  documents = corp$text,
  metadata = corp,
  language = "de",
  lowercase = TRUE,
  removestopwords = TRUE,
  removenumbers = TRUE,
  removepunctuation = TRUE,
  stem = TRUE,
  wordLengths = c(3, Inf),
  sparselevel = 1,
  onlycharacter = TRUE,
  striphtml = FALSE,
  customstopwords = c(
    # general: "dass" with "ss" and not "ß"; "amp" is what is left of "&amp;"
    c("dass", "amp"),
    # party names Austria
    c("spö", "övp", "fpö", "grüne", "neos", "pilz", "jetzt"),
    # party names Switzerland
    c("cvp", "pdc", "svp", "udc", "fdp", "plr", "bdp", "pbd",
      "sps", "pss", "glp", "gps", "pes", "mitte"),
    # party names Germany
    c("afd", "cdu", "csu", "cdu/csu", "cducsu",
      "fdp", "bündnis90", "die grünen", "linke", "spd")
  ),
  verbose = TRUE,
  v1 = FALSE
)

## Filtering out infrequent terms (lower.thresh = 10 is passed to prepDocuments)
out <- prepDocuments(
  documents = processed$documents,
  vocab = processed$vocab,
  meta = processed$meta,
  lower.thresh = 10
)

## Convenience copies of the three components of the prepared corpus
docs <- out[["documents"]]
vocab <- out[["vocab"]]
meta <- out[["meta"]]

## Checkpoint: write the preprocessed objects to disk and read them back in
save(
  list = c("corp", "processed", "out", "docs", "meta", "vocab"),
  file = "STM_Preprocessing_Data.RDA"
)

load(file = "STM_Preprocessing_Data.RDA")

#---------------------------------------------------------------------------------------------------------------------

# Identify range where optimal number of topics (k) is located

## Approach "searchK": fit one model per candidate k (k = 20, 40, ..., 200) and
## collect diagnostics for each of them

set.seed(100)

system.time({
  finding_k <- searchK(
    documents = out$documents,
    vocab = out$vocab,
    K = seq(20, 200, by = 20),               # candidate numbers of topics
    init.type = "Spectral",
    N = floor(0.1 * length(out$documents)),  # number of documents to be partially held out
    proportion = 0.5,                        # held-out proportion (package default)
    heldout.seed = 1234,                     # seed for the held-out split
    M = 10,                                  # M value for exclusivity computation (package default)
    cores = 1,                               # number of CPUs (package default)
    prevalence = ~ country,                  # topic prevalence covariate
    data = meta,
    max.em.its = 25,                         # maximum number of EM iterations
    verbose = TRUE
  )
})

## Checkpoint: store the search result and read it back in
save(list = "finding_k", file = "STM_Model.RDA")

load(file = "STM_Model.RDA")

## Diagnose models and get range where optimal k is located

## Each diagnostic is pulled out of the searchK result as a numeric vector named
## by k and then turned into a data frame sorted by k.

### semantic coherence
semcoh <- setNames(as.numeric(finding_k$results$semcoh), finding_k$results$K)
df_semcoh <- arrange(
  data.frame(k = as.numeric(names(semcoh)), semcoh = semcoh),
  k
)

### exclusivity
exclus <- setNames(as.numeric(finding_k$results$exclus), finding_k$results$K)
df_exclus <- arrange(
  data.frame(k = as.numeric(names(exclus)), exclus = exclus),
  k
)

### held-out likelihood
heldout <- setNames(as.numeric(finding_k$results$heldout), finding_k$results$K)
df_heldout <- arrange(
  data.frame(k = as.numeric(names(heldout)), heldout = heldout),
  k
)

### residuals
residual <- setNames(as.numeric(finding_k$results$residual), finding_k$results$K)
df_residual <- arrange(
  data.frame(k = as.numeric(names(residual)), residual = residual),
  k
)

### plots: one point-and-line panel per diagnostic, arranged in a 2-column grid
grid.arrange(
  grobs = Map(
    function(diagnostic_df, metric, axis_title) {
      ggplot(diagnostic_df, aes(x = k, y = .data[[metric]])) +
        geom_point() +
        geom_line() +
        theme_bw() +
        labs(x = "Number of topics", y = axis_title)
    },
    list(df_semcoh, df_exclus, df_heldout, df_residual),
    c("semcoh", "exclus", "heldout", "residual"),
    c("Semantic coherence", "Exclusivity", "Held-out likelihood", "Residuals")
  ),
  ncol = 2
)

### range of k chosen for the final models below
range_k <- 80:100
paste0("Range where the optimal number of topics (k) is located: ", min(range_k), "-", max(range_k), ".")

#---------------------------------------------------------------------------------------------------------------------

# Run final topic models

## The five candidate models (k = 80, 85, 90, 95, 100) share every setting
## except the number of topics, so the stm() call is written once.

## Fit a structural topic model with `k` topics on the prepared corpus.
##   k       number of topics
##   prepped list with `documents`, `vocab` and `meta` (defaults to `out`)
## Returns the fitted stm object.
fit_stm <- function(k, prepped = out) {
  stm(documents = prepped$documents,
      vocab = prepped$vocab,
      K = k,
      seed = 1234,
      prevalence = ~ country,
      max.em.its = 25, # maximum number of EM iterations
      data = prepped$meta,
      init.type = "Spectral",
      verbose = FALSE)
}

## system.time() evaluates its argument in the calling frame, so the fitted
## models are assigned at top level under their usual names.
system.time(stm_80 <- fit_stm(80))
system.time(stm_85 <- fit_stm(85))
system.time(stm_90 <- fit_stm(90))
system.time(stm_95 <- fit_stm(95))
system.time(stm_100 <- fit_stm(100))

## Checkpoint: this overwrites the earlier STM_Model.RDA, which held only finding_k
save(finding_k,
     stm_80, stm_85, stm_90, stm_95, stm_100,
     file = "STM_Model.RDA")

load("STM_Model.RDA")

#---------------------------------------------------------------------------------------------------------------------

# Evaluation topic models

## call validation script
## NOTE(review): `final_stm` and `summary_85` are used below but are not created
## anywhere in this file; presumably STM_Validation.R defines them - confirm.
source("STM_Validation.R")

## Report the selected model.
## NOTE(review): k is read from final_stm, but precision and TLO are always taken
## from summary_85, so the message is only consistent when the selected model is
## the k = 85 one - confirm.
paste0("The optimal STM model (final_stm) has k = ", 
       final_stm$settings$dim$K, 
       " as the model and topic intrusion tests show high model precision (", 
       round(summary_85$rater_precision, digits = 2),
       ") and high mean TLO (",
       round(mean(summary_85$tlo), digits = 2),
       ") for this model compared to the other models under investigation (k = 80, k = 90, k = 95 or k = 100).")

#---------------------------------------------------------------------------------------------------------------------

# Prepare data set for topic labeling

## One row per topic: element 5 of the labelTopics() result (renamed to
## topic_number below), element 1 of the same result (presumably the
## highest-probability words, per the stm documentation - confirm), and three
## empty columns to be filled in by hand.
df_topics_unlabeled <- local({
  topic_words <- labelTopics(final_stm, n = 10)
  cbind(
    data.frame(topic_words[5]),
    data.frame(topic_words[1]),
    topic_label = NA,
    policy_related = NA,
    note = NA
  )
})

names(df_topics_unlabeled)[1] <- "topic_number"

write.xlsx(df_topics_unlabeled, file = "STM_Topics_unlabeled.xlsx", row.names = FALSE)

#---------------------------------------------------------------------------------------------------------------------

# Topic labeling

## Topics are labeled manually, based on the words with the highest probability
## within each topic; the labeled sheet is read back in here.

## Load labeled topics
df_topics_labeled <- read.xlsx(file = "STM_Topics_labeled.xlsx", sheetIndex = 1)

## Topic overview: number of topics per label, then the distinct labels
sort(table(df_topics_labeled[["topic_label"]]), decreasing = TRUE)
sort(unique(df_topics_labeled[["topic_label"]]))

#---------------------------------------------------------------------------------------------------------------------

# Assign topics to documents

## Document-topic proportions from the selected model (one row per document,
## one column per topic)
stm_res <- final_stm$theta

## Neutralise topics that must never be assigned to a document: their columns
## are set to NA so that every label built from them becomes NA as well.
## Rows of df_topics_labeled are matched to columns of theta by position
## (NOTE(review): assumes the labeled sheet is still ordered by topic number - confirm).
neutral_labels <- c("Legislation and initiatives", "Elections", "NA")
idx_na <- which(df_topics_labeled$topic_label %in% neutral_labels)
stm_res[, idx_na] <- NA

## Topic probabilities for each unique label: sum the theta columns of all
## topics that share the label.
## The label vector is computed once and iterated directly. sort() drops missing
## labels, so iterating over 1:length(unique(...)) instead would run one step
## too far whenever a topic has no label.
topic_labels <- sort(unique(df_topics_labeled$topic_label))

theta_by_label <- lapply(topic_labels, function(label) {
  idx <- which(df_topics_labeled$topic_label == label)
  # drop = FALSE keeps a matrix even when a label belongs to a single topic;
  # rowSums() propagates NA, so neutralised labels stay NA
  rowSums(stm_res[, idx, drop = FALSE])
})
names(theta_by_label) <- topic_labels

## check.names = FALSE keeps labels containing blanks (e.g. "Housing and
## infrastructure") intact as column names
stm_theta <- data.frame(theta_by_label, check.names = FALSE)

# custom penalty for probabilities of overrepresented categories in tweets (i.e. Environment; Housing and infrastructure)
# rows of stm_theta are aligned with rows of meta, so tweets are located via meta
tweet_idx <- which(meta$type == "Tweet")

stm_theta[tweet_idx, "Environment"] <- stm_theta[tweet_idx, "Environment"] - 0.05
stm_theta[tweet_idx, "Housing and infrastructure"] <- stm_theta[tweet_idx, "Housing and infrastructure"] - 0.05

# final data set
# per modelled document: the highest label probability and the column index of that label
doc_topics <- data.frame(meta,
                         max_theta = apply(stm_theta, 1, max, na.rm = TRUE),
                         topic_number = apply(stm_theta, 1, which.max))

# join on all columns shared by corp and the document metadata (the same keys
# the natural join would pick, stated explicitly); rows of corp without a match
# in meta keep NA for max_theta and topic_number
df_res <- left_join(corp,
                    doc_topics,
                    by = intersect(names(corp), names(doc_topics)))

# topic_number is the column position within stm_theta, i.e. an index into the
# sorted unique labels, not the original STM topic number
df_topic_label <- data.frame(topic_number = seq_len(ncol(stm_theta)),
                             topic_label = colnames(stm_theta))

df_res <- left_join(df_res,
                    df_topic_label[, c("topic_number", "topic_label")],
                    by = "topic_number")

## neutralise TV event information as sometimes classified as media policy topic
## NOTE(review): everything before "Tipp" in the pattern is optional, so any
## text containing "Tipp" is relabelled - confirm this is intended
df_res$topic_label[str_detect(df_res$text, "T?V?-?[[:space:]]?Tipp:?")] <- "Event information"

## exclude short texts: keep rows with more than 5 runs of non-word characters
## NOTE(review): the filter is applied to every row, not only to tweets
idx_length <- which(str_count(string = df_res$text, pattern = "\\W+") > 5)
df_res <- df_res[idx_length, ]

## remove french language tweets (looked up by text_id in the tweet corpus)
idx_french <- corpus_tweets$text_id[corpus_tweets$lang %in% "fr"]
df_res <- df_res[!(df_res$text_id %in% idx_french), ]

## assign "NA" to tweets with low maximum theta (maximum theta < 0.1)
df_res$topic_label[with(df_res, type == "Tweet" & max_theta < 0.1)] <- "NA"

## check topic frequency by text type
table(df_res[["topic_label"]], df_res[["type"]])

#---------------------------------------------------------------------------------------------------------------------

# Save results

## Write the document-level topic assignments to disk
saveRDS(object = df_res, file = "STM_Result.RDS")

#---------------------------------------------------------------------------------------------------------------------

# Get run time of script

## Elapsed time between the start of the script and now
end_time <- Sys.time()

duration <- difftime(end_time, start_time)
duration

#---------------------------------------------------------------------------------------------------------------------
